import sys
sys.path.append("../SEV/")
import numpy as np
import pandas as pd
from FlexClustSEV import FlexClustSEV
from FCMCluster import FuzzyCMeans_base, FuzzyCMeans
from ClusterSEV import ClusterSEV
from FlexibleSEV import FlexibleSEV
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from Encoder import DataEncoder
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from TreeSEV import TreeSEV,tree_to_dict,merge_redundant_leaves,generate_path
from scipy.stats import gaussian_kde

# Load the FICO data and separate the RiskPerformance label from the features.
data = pd.read_csv("../../Data/fico.txt")
target = "RiskPerformance"
y = data[target]
X = data.drop(columns=target)
# Rows whose label equals 0; the encoder below is fitted on these rows only.
X_neg = X[y == 0]

# Fit the encoder on the label-0 rows, then encode both the full feature
# table and that subset with the same fitted transform (full table first).
encoder = DataEncoder(standard=True)
encoder.fit(X_neg)
encoded_X, encoded_X_neg = [encoder.transform(frame) for frame in (X, X_neg)]

# Read the explanation sample. Row 0 is the query being explained (it is the
# point plotted as "Query" below); every later row is one explanation method's
# output, in the same order as `labels`.
explain_df = pd.read_csv("Explanations.csv").set_index("Unnamed: 0")

# Entries at or below this value are treated as "not set" by an explanation.
zero_tol = 1e-3

# For each explanation row (rows 1..n-1), the columns it sets above zero_tol.
explained_features = [
    explain_df.columns[explain_df.iloc[row] > zero_tol]
    for row in range(1, explain_df.shape[0])
]

# In the explanation rows, replace every entry below zero_tol with the query's
# value, so each row becomes the query with only that method's changes applied.
# The query row itself is deliberately excluded from the mask: masking it would
# turn its own negative / near-zero entries into nan, and since it is also the
# fill source those would end up as 0, silently altering the query.
below_tol = explain_df < zero_tol
below_tol.iloc[0] = False
explain_df = explain_df.mask(below_tol)
explain_df = explain_df.fillna(explain_df.iloc[0])
# Anything still missing (e.g. nan already present in the CSV) becomes 0.
explain_df = explain_df.fillna(0)
encoded_explain_df = encoder.transform(explain_df)

# PaCMAP projection: fit a 2-D embedding on the encoded label-0 rows, then
# place the full encoded data and the encoded explanation rows into the same
# plane, using encoded_X_neg as the basis for both transforms.
import pacmap
import matplotlib.pyplot as plt

pacmap_params = {
    "n_components": 2,
    "n_neighbors": None,
    "MN_ratio": 1,
    "FP_ratio": 2.0,
    "random_state": 42,
}
embedding = pacmap.PaCMAP(**pacmap_params)
encoded_X_neg_map = embedding.fit_transform(encoded_X_neg, init="pca")
# Transforms run in order: full data first, then the explanation rows.
encoded_X_map, encoded_explain_df_map = [
    embedding.transform(frame, encoded_X_neg)
    for frame in (encoded_X, encoded_explain_df)
]

# Legend names for explanation rows 1, 2, ... (row 0 is the query).
labels = ["SEV 1","SEV_C", "SEV F", "SEV C+F", "SEV T"]

# Hold out 20% of the encoded rows (fixed seed); only the training part is
# used for fitting below.
X_train, X_test, y_train, y_test = train_test_split(
    encoded_X, y, test_size=0.2, random_state=42
)

# A heavily regularised logistic regression (used for all plots below) and a
# depth-5 decision tree; fitted in that order on the same training split.
model = LogisticRegression(solver='liblinear', penalty='l2', C=1e-2)
tree_model = DecisionTreeClassifier(max_depth=5, random_state=42)
for clf in (model, tree_model):
    clf.fit(X_train, y_train)

import matplotlib.cm as cm

# Panel 1: every sample in the PaCMAP plane, coloured by the logistic model's
# predicted probability of class 1.
proba_class1 = model.predict_proba(encoded_X)[:, 1]
plt.figure(figsize=(15, 5))
plt.subplot(1, 3, 1)
plt.scatter(encoded_X_map[:, 0], encoded_X_map[:, 1], c=proba_class1, cmap="coolwarm", s=10)
plt.xticks([])
plt.yticks([])
cbar = plt.colorbar()
cbar.set_label('Predicted Probability', rotation=270)
plt.tight_layout()

# Panel 2: all samples as a faint grey background, with each explanation row
# (rows 1..n-1) drawn in its own Accent colour and the query (row 0) as a red star.
plt.subplot(1, 3, 2)
plt.scatter(encoded_X_map[:, 0], encoded_X_map[:, 1], c="gray", s=10, alpha=0.2)
plt.xticks([])
plt.yticks([])
colors = cm.Accent(np.linspace(0, 1, len(labels)))
for idx, (px, py) in enumerate(encoded_explain_df_map[1:, :2]):
    plt.scatter(px, py, label=labels[idx], color=colors[idx], s=10)
plt.scatter(encoded_explain_df_map[0, 0], encoded_explain_df_map[0, 1], c='red', label="Query", marker="*")
plt.legend()
# Panel 3: zoom in on the region around the explanation points, shade it with
# a probability-weighted KDE, and trace how each explanation moves the query in
# the embedding as its changed features are swapped in one at a time.
plt.subplot(1,3,3)

# Samples whose embedding lies inside the bounding box of the explanation
# points, padded by 0.1 below and 0.5 above on each axis. The mask is built
# once and applied to both the 2-D map and the encoded features so the two
# selections stay row-aligned.
in_window = (
    (encoded_X_map[:, 0] > np.min(encoded_explain_df_map[:, 0]) - 0.1)
    & (encoded_X_map[:, 0] < np.max(encoded_explain_df_map[:, 0]) + 0.5)
    & (encoded_X_map[:, 1] > np.min(encoded_explain_df_map[:, 1]) - 0.1)
    & (encoded_X_map[:, 1] < np.max(encoded_explain_df_map[:, 1]) + 0.5)
)
selected_X_map = encoded_X_map[in_window]
selected_X = encoded_X[in_window]
print(encoded_X_map.shape)
print(selected_X_map.shape)
print(selected_X.shape)
print(encoded_X.shape)

# Gaussian KDE of the windowed points, each weighted by the model's predicted
# probability of class 1, evaluated on a 100x100 grid spanning the window.
kde = gaussian_kde(selected_X_map.T, bw_method='silverman', weights=model.predict_proba(selected_X)[:,1])

xmin, xmax = selected_X_map[:,0].min(), selected_X_map[:,0].max()
ymin, ymax = selected_X_map[:,1].min(), selected_X_map[:,1].max()
xx, yy = np.mgrid[xmin:xmax:100j, ymin:ymax:100j]
positions = np.vstack([xx.ravel(), yy.ravel()])
f = np.reshape(kde(positions).T, xx.shape)

# f is indexed [x, y]; rot90 puts max-y in the first row and x along columns,
# which is the layout imshow (origin at the top) expects for this extent.
plt.imshow(np.rot90(f), cmap='coolwarm', extent=[xmin, xmax, ymin, ymax],alpha=0.2)

for i in range(1,encoded_explain_df_map.shape[0]):
    idx = i-1
    # The second explanation (idx 1) is drawn as a purple star; every other
    # one uses its Accent colour and the default marker. RGBA colours are
    # passed via `color=`, not `c=`, so matplotlib never tries to value-map them.
    if idx == 1:
        point_style = {"color": "purple", "marker": "*"}
    else:
        point_style = {"color": colors[idx]}
    plt.scatter(encoded_explain_df_map[i,0],encoded_explain_df_map[i,1],label=labels[idx],s=30,**point_style)

    # Walk from the query towards this explanation: apply its explained
    # features one at a time (in column order), re-embed the partially
    # updated query after each step, and draw an arrow for every step.
    # NOTE(review): this assumes the column names in explained_features
    # (taken from the raw explain_df) also exist in encoded_explain_df.
    query = encoded_explain_df.iloc[0].copy()
    reference = encoded_explain_df.iloc[i]
    current_point = encoded_explain_df_map[0]
    for feature in explained_features[idx]:
        query[feature] = reference[feature]
        next_point = embedding.transform(query.values.reshape(1,-1),encoded_X_neg)[0]
        plt.scatter(next_point[0],next_point[1],s=30,alpha=0.5,**point_style)
        plt.arrow(current_point[0], current_point[1],
                  next_point[0] - current_point[0], next_point[1] - current_point[1],
                  color=point_style["color"], linestyle="--", head_width=0.1, head_length=0.1)
        current_point = next_point

plt.scatter(encoded_explain_df_map[0,0],encoded_explain_df_map[0,1],c='red',label="Query",marker="*",s=100)
plt.xticks([])
plt.yticks([])
plt.legend()
# Lay out all three panels now that they exist (the earlier tight_layout call
# ran when only the first panel had been drawn), then save the image.
plt.tight_layout()
plt.savefig("Explanations.png")


# # evaluate the model
# y_pred_train = model.predict(X_train)
# y_pred_test = model.predict(X_test)

# # build different SEV methods
# originalSEV = FlexibleSEV(model,encoder,X.columns,encoded_X_neg,tol=0,k=1)
# # build flexible SEV method
# flexibleSEV = FlexibleSEV(model,encoder,X.columns,encoded_X_neg,tol=0.2,k=5)
# # build cluster SEV method
# clusterSEV = ClusterSEV(model,encoder,encoded_X.columns,encoded_X_neg,n_clusters=3,m=3)
# # build flexclust SEV method
# flexclustSEV = FlexClustSEV(model,encoder,encoded_X.columns,encoded_X_neg,n_clusters=3,m=3,tolerance=0.2,k=5)
# # build the treeSEV method
# tree_sev = TreeSEV(tree_model,X_test)

# print(X_train)

# tree_dict = tree_to_dict(tree_model,X_test.columns)
# tree_dict = merge_redundant_leaves(tree_dict)


# X_embedded = clusterSEV.embedding.transform(X_test,encoded_X_neg)